import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.ticker import MultipleLocator
from scipy import stats

# Global plotting defaults: large label/tick/legend fonts and a dotted black
# grid. Written into matplotlib's rcParams, so figures created afterwards
# pick them up.
params = dict(
    [
        ("axes.labelsize", 22),
        ("font.size", 22),
        ("legend.fontsize", 20),
        ("xtick.labelsize", 20),
        ("ytick.labelsize", 20),
        ("grid.color", "k"),
        ("grid.linestyle", ":"),
        ("grid.linewidth", 1),
    ]
)
mpl.rcParams.update(params)

# ISO-3166 country table with alpha-2 / alpha-3 codes and names.
# na_filter=False keeps every cell as a literal string, so a value such as
# "NA" is not turned into a missing value.
country_code = pd.read_csv(
    "https://raw.githubusercontent.com/lukes/ISO-3166-Countries-with-Regional-Codes/refs/heads/master/all/all.csv",
    na_filter=False,
)
# Lookup tables: alpha-2 code -> country name, and alpha-3 -> alpha-2 code.
dict_country_name = {
    alpha2: name
    for alpha2, name in zip(country_code["alpha-2"], country_code["name"])
}
dict_iso3_to_iso2 = {
    alpha3: alpha2
    for alpha3, alpha2 in zip(country_code["alpha-3"], country_code["alpha-2"])
}

# Estimated international migration flows, one row per month and
# origin/destination pair. The 2019 months are summed into one annual total
# per pair.
estimate = pd.read_csv("international_migration_flow.csv", na_filter=False)
estimate_2019 = (
    estimate.loc[lambda frame: frame["migration_month"].str.contains("2019")]
    .groupby(["country_from", "country_to"], as_index=False)["num_migrants"]
    .sum()
)

# Annual flows per origin/destination pair. The year is taken from the first
# four characters of "migration_month" and is therefore a string (e.g. "2019").
# The monthly counts are then summed within each year.
# `.str[:4]` slices every value in one vectorised pass. The former row-wise
# `apply(axis=1)` was much slower, and on an empty frame it returned a
# DataFrame that could not be assigned to a single column.
estimate_three_years = estimate.assign(
    migration_year=estimate["migration_month"].str[:4]
)
estimate_three_years = estimate_three_years.groupby(
    ["country_from", "country_to", "migration_year"], as_index=False
).agg({"num_migrants": "sum"})

# NOTE: these flows use the New Zealand government's definition of a migrant:
# someone who resides in the country for 12 months within a 16-month period,
# rather than the 6-months-within-12 rule used by most nations.
estimate_nz_monthly = pd.read_csv("migration_flow_nz.csv", na_filter=False)
# Collapse the monthly rows to one total per (year, origin, destination).
estimate_nz = (
    estimate_nz_monthly.groupby(["migration_year", "country_from", "country_to"])[
        "num_migrants"
    ]
    .sum()
    .reset_index()
)

# New Zealand
# download data from https://www.stats.govt.nz/information-releases/international-migration-december-2023/
# Reported ("gt") monthly migration figures from Stats NZ.
gt = pd.read_csv(
    "international-migration-December-2023-citizenship-by-visa-by-country-of-last-permanent-residence.csv"
)
# Keep only the all-citizenship / all-visa totals. Of those, keep rows whose
# status is "Final", plus every 2022 row regardless of status.
gt_filtered = gt.loc[(gt["citizenship"] == "TOTAL") & (gt["visa"] == "TOTAL")]
gt_filtered = gt_filtered.loc[
    (gt_filtered["status"] == "Final")
    | gt_filtered["year_month"].str.contains("2022")
]

# One slice per calendar year, matched on the text of "year_month".
gt_2019, gt_2020, gt_2021, gt_2022 = (
    gt_filtered[gt_filtered["year_month"].str.contains(year)]
    for year in ("2019", "2020", "2021", "2022")
)

# 2019 totals per country of residence, with columns (country_from, gt_num).
# NOTE(review): every remaining row is summed. If the file also breaks rows
# down further (e.g. by direction), that column must be filtered first — confirm.
gt_agg_2019 = (
    gt_2019.groupby("country_of_residence")["estimate"]
    .sum()
    .reset_index(name="gt_num")
    .rename(columns={"country_of_residence": "country_from"})
)

# data from https://gist.github.com/tadast/8827699
# Build a country-name -> ISO alpha-2 lookup so the Stats NZ country names can
# be joined with the estimates' two-letter origin codes.
country_iso2_df = pd.read_csv(
    "https://gist.githubusercontent.com/tadast/8827699/raw/61b2107766d6fd51e2bd02d9f78f6be081340efc/countries_codes_and_coordinates.csv",
    na_filter=False,
)
# .copy(): a column subset is flagged by pandas as derived from the original
# frame, and adding the "iso2" column to it raised SettingWithCopyWarning.
country_iso2_df = country_iso2_df[["Country", "Alpha-2 code"]].copy()
# Strip quote characters and spaces from the codes. This is one vectorised
# pass instead of a row-wise apply.
country_iso2_df["iso2"] = (
    country_iso2_df["Alpha-2 code"]
    .str.replace('"', "", regex=False)
    .str.replace(" ", "", regex=False)
)
name2iso2 = dict(zip(country_iso2_df.Country, country_iso2_df.iso2))
# Manual entries for country names as spelled in the Stats NZ data
# (presumably absent from, or spelled differently in, the gist).
name2iso2.update(
    {
        "China, People's Republic of": "CN",
        "United States of America": "US",
        "Hong Kong (Special Administrative Region)": "HK",
        "Czech Republic (to Sep 2018), Czechia (from Oct 2018)": "CZ",
        "Iran": "IR",
        "Laos": "LA",
        "Macau (Special Administrative Region)": "MO",
        "Kosovo": "XK",
        "Cote d'Ivoire": "CI",
        "Moldova": "MD",
        "North Macedonia": "MK",
        "St Lucia": "LC",
        "Syria": "SY",
        "Tanzania": "TZ",
    }
)

# Names with no known code get the placeholder "NULL". Unlike a row-wise
# apply, Series.map also works on an empty frame.
gt_agg_2019["iso2"] = gt_agg_2019["country_from"].map(
    lambda name: name2iso2.get(name, "NULL")
)

# 2019 estimated flows into New Zealand, left-joined with the Stats NZ totals
# on the origin's alpha-2 code. Origins without a reported total get NaN gt_num.
estimate_to_nz = estimate_nz.loc[estimate_nz["migration_year"] == 2019]
df_nz = pd.merge(
    estimate_to_nz.loc[:, ["country_from", "num_migrants"]],
    gt_agg_2019.loc[:, ["iso2", "gt_num"]],
    how="left",
    left_on="country_from",
    right_on="iso2",
)


# monthly overall
# Stats NZ "TOTAL" rows (all countries of residence) for 2019-2021, in month
# order.
gt_2019_to_2021 = pd.concat([gt_2019, gt_2020, gt_2021])
gt_2019_to_2021_month = gt_2019_to_2021.loc[
    gt_2019_to_2021["country_of_residence"] == "TOTAL"
].sort_values(by="year_month")

# Estimated arrivals for the same years, summed over origins per month.
_estimate = (
    estimate_nz_monthly.loc[
        estimate_nz_monthly["migration_year"].isin([2019, 2020, 2021])
    ]
    .groupby("migration_month", as_index=False)["num_migrants"]
    .sum()
)
# Keep only months present in both sources, ordered chronologically.
df_nz_monthly = (
    _estimate.merge(
        gt_2019_to_2021_month,
        how="inner",
        left_on="migration_month",
        right_on="year_month",
    )
    .sort_values(by="migration_month")
    .reset_index()
)


# monthly from India to New Zealand
# India is "India" in the Stats NZ file and "IN" in the estimates.
gt_2019_to_2021_month_from_india = gt_2019_to_2021.loc[
    gt_2019_to_2021["country_of_residence"] == "India"
].sort_values(by="year_month")

# NOTE(review): these estimate rows are not summed per month, which assumes at
# most one "IN" row per migration_month — confirm, or duplicates will appear
# after the merge.
_estimate = estimate_nz_monthly.loc[
    estimate_nz_monthly["migration_year"].isin([2019, 2020, 2021])
    & (estimate_nz_monthly["country_from"] == "IN")
]
df_nz_monthly_from_india = (
    _estimate.merge(
        gt_2019_to_2021_month_from_india,
        how="inner",
        left_on="migration_month",
        right_on="year_month",
    )
    .sort_values(by="migration_month")
    .reset_index()
)


# Eurostat
# download data from https://ec.europa.eu/eurostat/databrowser/product/view/migr_imm5prv?lang=en&category=migr.migr_cit.migr_immi
# Eurostat-reported 2019 flows, paired with the 2019 estimates for the same
# (origin, destination). The inner join drops pairs missing from either side.
eurostat = pd.read_csv("cleaned_eurostat_migration.csv", na_filter=False)
eurostat = eurostat[eurostat.year == 2019]

df_eu = eurostat.merge(
    estimate_2019,
    left_on=["origin", "destination"],
    right_on=["country_from", "country_to"],
    how="inner",
)
# sort_values returns a new frame. Its result was previously discarded, so the
# sort had no effect.
df_eu = df_eu.sort_values(by="eurostat_migration")


# all together
# Panels (A) and (B): reported vs. estimated annual flows on log-log scatter
# plots, each with a y = x reference line and Pearson r in the title.
# fig.tight_layout() used to be called here before any axes existed, so it
# did nothing. Spacing is set explicitly below instead.
fig = plt.figure()
fig.set_size_inches(11, 11)
fig.subplots_adjust(wspace=0.25, hspace=0.35)

# (A) New Zealand, 2019: reported (x) vs. estimated (y) flow per origin.
ax1 = fig.add_subplot(2, 2, 1)
nz = df_nz.copy()
plt.scatter(nz.gt_num, nz.num_migrants, s=5, c="k")
plt.ylabel("Facebook estimate")
x_min = 0
x_max = 30000
plt.plot([x_min, x_max], [x_min, x_max], "--", color="steelblue")
# Origins with no reported value (NaN after the left join) are excluded from r.
nz = nz[nz.gt_num.notna()]
r = stats.pearsonr(nz.gt_num, nz.num_migrants)
# Raw string: "\i" is not a valid escape sequence (SyntaxWarning on 3.12+).
plt.title(
    "(A)\nAnnual immigration flows to\nNew Zealand in 2019 ("
    + r"$\it{r}$"
    + "={:.2f})".format(r[0]),
    fontsize=15,
)

plt.xlim(x_min + 1, x_max)
plt.ylim(x_min + 1, x_max)
plt.xlabel("Reported migration flow")
# Switching to log scale installs the default log locators/formatters. Custom
# x ticks therefore have to be set after this point: plt.xticks/yticks calls
# made before it were discarded, and have been removed.
plt.xscale("log")
plt.yscale("log")
# Major x ticks at each power of ten; unlabeled minor ticks at 0.2-0.9 x 10^k.
locmaj = mpl.ticker.LogLocator(base=10.0, subs=(1.0,), numticks=100)
ax1.xaxis.set_major_locator(locmaj)

locmin = mpl.ticker.LogLocator(base=10.0, subs=np.arange(2, 10) * 0.1, numticks=100)
ax1.xaxis.set_minor_locator(locmin)
ax1.xaxis.set_minor_formatter(mpl.ticker.NullFormatter())


# (B) Eurostat, 2019: reported (x) vs. estimated (y) flow per country pair.
ax2 = fig.add_subplot(2, 2, 2)
eu = df_eu.copy()
plt.scatter(eu.eurostat_migration, eu.num_migrants, s=5, c="k")
x_min = 0
x_max = 120000
plt.plot([x_min, x_max], [x_min, x_max], "--", color="steelblue")
eu = eu[eu.eurostat_migration.notna()]
r = stats.pearsonr(eu.eurostat_migration, eu.num_migrants)
plt.title(
    "(B)\nAnnual immigration flows to\nEuropean countries in 2019 ("
    + r"$\it{r}$"
    + "={:.2f})".format(r[0]),
    fontsize=15,
)
plt.xlim(x_min + 1, x_max)
plt.ylim(x_min + 1, x_max)
plt.xlabel("Reported migration flow")
plt.xscale("log")
plt.yscale("log")
locmaj = mpl.ticker.LogLocator(base=10.0, subs=(1.0,), numticks=100)
ax2.xaxis.set_major_locator(locmaj)

locmin = mpl.ticker.LogLocator(base=10.0, subs=np.arange(2, 10) * 0.1, numticks=100)
ax2.xaxis.set_minor_locator(locmin)
ax2.xaxis.set_minor_formatter(mpl.ticker.NullFormatter())


# (C) Monthly flows from India to New Zealand for 2019-2020. The x axis is the
# month position (0 = first row), and values are plotted in thousands.
first_two_years = df_nz_monthly_from_india[
    df_nz_monthly_from_india.migration_year <= 2020
]
ax4 = fig.add_subplot(2, 1, 2)
india = first_two_years.copy()
india["estimate"] = india["estimate"] / 1000
india["num_migrants"] = india["num_migrants"] / 1000
plt.plot(
    range(len(india)), india.estimate, "-o", c="k", label="Reported migration flow"
)
plt.plot(
    range(len(india)), list(india.num_migrants), "-o", c="g", label="Facebook estimate"
)

# Quarterly labels at explicit positions 0, 3, ..., 21. Previously the labels
# were attached to a MultipleLocator, which triggers matplotlib's "FixedFormatter
# should only be used together with FixedLocator" warning. That approach also
# needed a "" placeholder for an off-screen tick at -3.
# NOTE(review): assumes one row per month with no gaps from Jan 2019 (the
# inner joins drop months missing from either source) — confirm.
ax4.set_xticks(list(range(0, 22, 3)))
ax4.set_xticklabels(
    [
        "Jan 2019",
        "Apr 2019",
        "Jul 2019",
        "Oct 2019",
        "Jan 2020",
        "Apr 2020",
        "Jul 2020",
        "Oct 2020",
    ],
    minor=False,
    fontsize=15,
)
ax4.xaxis.set_minor_locator(MultipleLocator(1))
ax4.tick_params(which="major", length=6)
ax4.tick_params(which="minor", length=3)

plt.legend(fontsize=15)
plt.title(
    "(C)\nMonthly migration flows from India to New Zealand",
    fontsize=15,
)
plt.ylabel("# Migrants (K)")
plt.xlim(0, len(india) - 1)
plt.ylim(-0.3, 5.5)
plt.yticks(range(6), range(6))
plt.gca().yaxis.grid(True, linestyle="-", c="lightgray")
plt.savefig(
    "figures/fig4.pdf",
    facecolor="white",
    transparent=False,
    bbox_inches="tight",
    dpi=300,
)
